import bisect
import itertools
from functools import cache
from heapq import heappush, heappop
from math import sqrt

from .data_struct import *


def solution_1480(nums: List[int]) -> List[int]:
    """Return the running (prefix) sums of ``nums``.

    Element ``i`` of the result equals ``nums[0] + ... + nums[i]``; an empty
    input produces an empty list.
    """
    return list(itertools.accumulate(nums))


def solution_1342(num: int) -> int:
    """Count the steps needed to bring ``num`` down to zero.

    One step either halves an even value or subtracts one from an odd value.
    """
    steps = 0
    while num > 1:
        # An odd value costs two steps (subtract one, then halve),
        # an even value costs a single halving.
        steps += 2 if num & 1 else 1
        num >>= 1
    # For non-negative input the remainder is 0 or 1; a remaining 1 needs one
    # last subtraction.
    return steps + (num & 1)


def solution_1342_2(num: int) -> int:
    """Count the steps to reduce ``num`` to zero using bit tricks.

    Assumes ``num`` fits in an unsigned 32-bit integer. For a non-zero value
    the answer is ``(bit length - 1)`` halvings plus one subtraction per set
    bit.
    """
    if num == 0:
        return 0

    # Count leading zeros of the 32-bit value (used to derive the bit length)
    # by binary search: whenever the bits above ``shift`` are all zero, add
    # ``width`` to the count and shift those empty bits out.
    leading_zeros = 0
    probe = num
    for shift, width in ((16, 16), (24, 8), (28, 4), (30, 2), (31, 1)):
        if probe >> shift == 0:
            leading_zeros += width
            probe <<= width

    # Count the set bits by folding: first sum each pair of bits into that
    # pair, then sum neighbouring 2-bit results into 4 bits, and so on until
    # the whole 32-bit word holds a single total.
    ones = num
    folds = (
        (0x55555555, 1),
        (0x33333333, 2),
        (0x0F0F0F0F, 4),
        (0x00FF00FF, 8),
        (0x0000FFFF, 16),
    )
    for mask, shift in folds:
        ones = (ones & mask) + ((ones >> shift) & mask)

    return 32 - leading_zeros - 1 + ones


def solution_1185(day: int, month: int, year: int) -> str:
    """Return the weekday name of the given date.

    Days are counted from 1971-01-01, which was a Friday. Leap years are
    detected with ``year % 4`` only, with an explicit one-day correction for
    2100, so the result is intended for dates from 1971 through 2100.
    """
    weekday_names = ("Friday", "Saturday", "Sunday", "Monday",
                     "Tuesday", "Wednesday", "Thursday")
    # Days elapsed at the start of each year of a 4-year cycle that begins in
    # 1971 (the second year of the cycle, 1972, is the leap year).
    cycle_offsets = (0, 365, 731, 1096)
    # Days elapsed before the first of each month.
    before_month_common = (0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334)
    before_month_leap = (0, 31, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335)

    cycles, year_in_cycle = divmod(year - 1971, 4)
    elapsed = cycles * 1461 + cycle_offsets[year_in_cycle]
    if year % 4 == 0:
        elapsed += before_month_leap[month - 1]
    else:
        elapsed += before_month_common[month - 1]
    elapsed += day - 1
    # 2100 is divisible by 4 but is not a leap year: undo the extra day that
    # the leap table added for dates after February.
    if year == 2100 and month > 2:
        elapsed -= 1
    return weekday_names[elapsed % 7]


def solution_1185_2(day: int, month: int, year: int) -> str:
    """Return the weekday name of a date using Zeller's congruence.

    The formula evaluated is::

        w = (y + y // 4 + c // 4 - 2 * c + 26 * (m + 1) // 10 + d - 1) mod 7

    where ``c`` is the century, ``y`` the year within the century, ``m`` the
    month and ``d`` the day. January and February are treated as months 13
    and 14 of the preceding year. ``w == 0`` means Sunday.

    All arithmetic uses integer floor division, so no precision is lost for
    large years, and Python's ``%`` already yields a non-negative remainder,
    so a negative intermediate total needs no special handling.
    """
    weekday_names = ("Sunday", "Monday", "Tuesday", "Wednesday",
                     "Thursday", "Friday", "Saturday")
    if month in (1, 2):
        # Shift January/February to the end of the previous year.
        month += 12
        year -= 1
    century, year_of_century = divmod(year, 100)
    total = (year_of_century + year_of_century // 4 + century // 4
             - 2 * century + 26 * (month + 1) // 10 + day - 1)
    return weekday_names[total % 7]


def solution_1154(date: str) -> int:
    """Return the 1-based day number within its year of a ``YYYY-MM-DD`` date."""
    year, month, day = map(int, date.split("-"))
    # Days elapsed before the first of each month.
    before_month_common = (0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334)
    before_month_leap = (0, 31, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335)
    # Gregorian rule: every 4th year, except centuries not divisible by 400.
    is_leap = year % 400 == 0 or (year % 4 == 0 and year % 100 != 0)
    before_month = before_month_leap if is_leap else before_month_common
    return before_month[month - 1] + day


def solution_1094(trips: List[List[int]], capacity: int) -> bool:
    """Tell whether all trips fit in a vehicle with ``capacity`` seats.

    Each trip is ``[passengers, pick_up, drop_off]``. Passengers board at
    ``pick_up`` and have left by ``drop_off``, so a drop-off and a pick-up at
    the same location do not overlap.

    The load is tracked with a difference array sized from the farthest
    location in ``trips`` rather than a fixed length, so arbitrarily large
    (non-negative) locations are supported.
    """
    if not trips:
        # With no trips the load is always zero.
        return 0 <= capacity
    size = max(max(trip[1], trip[2]) for trip in trips) + 1
    delta = [0] * size
    for passengers, pick_up, drop_off in trips:
        delta[pick_up] += passengers
        delta[drop_off] -= passengers
    # The prefix sums of ``delta`` are the load at every location.
    return max(itertools.accumulate(delta)) <= capacity


def solution_1483():
    """Placeholder for problem 1483 (tagged "lca" in the original note).

    This function holds no logic; the original note points to
    ``method.TreeAncestor`` as the place where the solution lives.
    """


def solution_1261():
    """Placeholder for problem 1261.

    This function holds no logic; the original note points to
    ``data_struct.FindElements`` as the place where the solution lives.
    """


def solution_1004(nums: List[int], k: int) -> int:
    """Return the length of the longest window of ``nums`` with at most ``k`` zeros.

    ``nums`` is expected to contain only 0s and 1s.
    """
    best = 0
    start = 0
    zeros = 0
    for end, bit in enumerate(nums):
        zeros += 1 - bit
        # Shrink from the left until the window holds at most k zeros again.
        while zeros > k:
            zeros -= 1 - nums[start]
            start += 1
        best = max(best, end - start + 1)
    return best


def solution_1379(original: TreeNode, cloned: TreeNode, target: TreeNode) -> TreeNode:
    """Return the node of ``cloned`` at the position ``target`` has in ``original``.

    Both trees are walked in lockstep (left subtree first) with an explicit
    stack, so deep trees cannot exhaust the recursion limit. ``target`` is
    located by identity (``is``), matching ``solution_1379_2``, because the
    task is about one specific node object rather than an equal-valued one.

    ``cloned`` is assumed to have the same shape as ``original``.

    Returns:
        The matching node of ``cloned``, or ``None`` when ``target`` is not
        part of ``original``.
    """
    stack = [(original, cloned)]
    while stack:
        node, twin = stack.pop()
        if node is None:
            continue
        if node is target:
            return twin
        # Push right first so that the left subtree is explored first.
        stack.append((node.right, twin.right))
        stack.append((node.left, twin.left))
    return None


def solution_1379_2(original: TreeNode, cloned: TreeNode, target: TreeNode) -> TreeNode:
    """Recursively find the node of ``cloned`` mirroring ``target`` in ``original``.

    Both trees are descended in parallel. When ``original`` runs out
    (``None``) the parallel ``cloned`` reference is returned as-is, which is
    presumably ``None`` too when the trees have the same shape.
    """
    if original is None:
        return cloned
    if original is target:
        return cloned
    found_left = solution_1379_2(original.left, cloned.left, target)
    if found_left:
        return found_left
    return solution_1379_2(original.right, cloned.right, target)


def solution_1026(root: Optional[TreeNode]) -> int:
    """Return the largest ``|a.val - b.val|`` where ``a`` is an ancestor of ``b``.

    For every node only the smallest and largest value on the path from the
    root matter, so those two numbers are carried along the traversal instead
    of a sorted container of all ancestors. The traversal uses an explicit
    stack, so a degenerate (list-like) tree cannot hit the recursion limit.

    Returns:
        The maximum difference, or ``0`` for an empty tree.
    """
    if root is None:
        return 0
    best = 0
    # Each entry: (node, min value among its ancestors, max value among them).
    stack = [(root, root.val, root.val)]
    while stack:
        node, low, high = stack.pop()
        value = node.val
        best = max(best, abs(value - low), abs(value - high))
        low = min(low, value)
        high = max(high, value)
        for child in (node.left, node.right):
            if child is not None:
                stack.append((child, low, high))
    return best


def solution_1052(customers: List[int], grumpy: List[int], minutes: int) -> int:
    """Return the maximum number of satisfied customers.

    Customers arriving in a minute where ``grumpy`` is 0 are always
    satisfied. In addition, one window of ``minutes`` consecutive minutes can
    be chosen in which the customers of grumpy minutes are satisfied as well.

    The input lists are left untouched (the recoverable customers are kept in
    a separate list instead of overwriting ``customers``).
    """
    always_satisfied = sum(c for c, g in zip(customers, grumpy) if g == 0)
    if minutes <= 0:
        # An empty window cannot recover anyone.
        return always_satisfied

    # Customers that are only satisfied if the window covers their minute.
    recoverable = [0 if g == 0 else c for c, g in zip(customers, grumpy)]
    best = window = 0
    for right, gain in enumerate(recoverable):
        window += gain
        if right >= minutes:
            # Drop the minute that just slid out of the window.
            window -= recoverable[right - minutes]
        best = max(best, window)
    return always_satisfied + best


def solution_1146():
    """Placeholder for problem 1146.

    This function holds no logic; the original note points to
    ``data_struct#SnapshotArray`` as the place where the solution lives.
    """


def solution_1329(mat: List[List[int]]) -> List[List[int]]:
    """Sort every top-left to bottom-right diagonal of ``mat`` in ascending order.

    The matrix is modified in place and also returned.
    """
    rows = len(mat)
    cols = len(mat[0])
    # Every diagonal starts either in the first column or in the first row.
    starts = [(r, 0) for r in range(rows)] + [(0, c) for c in range(1, cols)]
    for first_row, first_col in starts:
        length = min(rows - first_row, cols - first_col)
        cells = [(first_row + step, first_col + step) for step in range(length)]
        ordered = sorted(mat[r][c] for r, c in cells)
        for (r, c), value in zip(cells, ordered):
            mat[r][c] = value
    return mat


def solution_1017(n: int) -> str:
    """Return ``n`` written in base -2 (negabinary) without leading zeros.

    Let ``mask`` be the number with a 1 at every even bit position
    (``...0101``). If ``v`` is the bit pattern of the negabinary digits of
    ``n``, then flipping the even bits of ``v`` yields exactly ``mask - n``,
    hence ``v = mask ^ (mask - n)``.

    The identity holds whenever ``-2 * mask <= n <= mask``; the mask starts
    at 32 bits and is widened two bits at a time until ``n`` fits, so
    arbitrarily large values are handled.
    """
    mask = 0x5555_5555
    while not -2 * mask <= n <= mask:
        mask = (mask << 2) | 1
    return format(mask ^ (mask - n), "b")


def solution_1491(salary: List[int]) -> float:
    """Return the mean of ``salary`` after removing one minimum and one maximum.

    Expects at least three entries: the divisor is ``len(salary) - 2``.
    """
    lowest = highest = salary[0]
    total = 0
    for pay in salary:
        total += pay
        lowest = min(lowest, pay)
        highest = max(highest, pay)
    return (total - lowest - highest) / (len(salary) - 2)


def solution_1235(startTime: List[int], endTime: List[int], profit: List[int]) -> int:
    """Return the maximum total profit of a set of non-overlapping jobs.

    Job ``i`` runs from ``startTime[i]`` to ``endTime[i]`` and earns
    ``profit[i]``; a job may start exactly when another one ends.

    Jobs are sorted by start time and ``best[i]`` is the optimum using only
    jobs ``i..``: either skip job ``i`` or take it and continue with the
    first job that starts at or after its end (found by binary search). The
    table is filled bottom-up, so the depth of the computation does not grow
    with the number of jobs (no recursion limit to hit).
    """
    jobs = sorted(zip(startTime, endTime, profit), key=lambda job: job[0])
    starts = [job[0] for job in jobs]
    count = len(jobs)
    best = [0] * (count + 1)
    for i in range(count - 1, -1, -1):
        _, end, gain = jobs[i]
        # First later job whose start is not before this job's end.
        following = bisect.bisect_left(starts, end, lo=i + 1)
        best[i] = max(best[i + 1], best[following] + gain)
    return best[0]


def solution_1463(grid: List[List[int]]) -> int:
    """Return the most cherries two robots can collect walking down ``grid``.

    Robot 1 starts in the top-left cell and robot 2 in the top-right cell.
    Every step each robot moves one row down and shifts its column by -1, 0
    or +1. A cell visited by both robots in the same row is counted once.
    """
    rows = len(grid)
    cols = len(grid[0])
    shifts = (-1, 0, 1)

    @cache
    def best(row, col_a, col_b):
        """Best total from ``row`` downward with the robots at the given columns.

        The value is the cherries of the current row plus the maximum over
        the nine combinations of column shifts for the next row.
        """
        if row >= rows:
            return 0
        if not (0 <= col_a < cols and 0 <= col_b < cols):
            # A robot left the grid: this state is unreachable.
            return -inf
        here = grid[row][col_a]
        if col_a != col_b:
            here += grid[row][col_b]
        below = max(
            best(row + 1, col_a + shift_a, col_b + shift_b)
            for shift_a, shift_b in itertools.product(shifts, repeat=2)
        )
        return below + here

    answer = best(0, 0, cols - 1)
    best.cache_clear()
    return answer


def solution_1103(candies: int, num_people: int) -> List[int]:
    """Distribute ``candies`` to ``num_people`` people in growing portions.

    Portions of 1, 2, 3, ... candies are handed out round-robin; the last
    portion is whatever is left. Complete rounds are accounted for in bulk,
    and only the final, partial round is simulated.
    """
    people = num_people
    first_round = (people + 1) * people // 2

    def round_total(index):
        # Candies handed out in the complete round number ``index`` (0-based).
        return index * people * people + first_round

    remaining = candies
    full_rounds = 0
    while remaining > 0:
        remaining -= round_total(full_rounds)
        full_rounds += 1
    if remaining != 0:
        # The last round was only partial: undo it and simulate it below.
        full_rounds -= 1
        remaining += round_total(full_rounds)

    if full_rounds > 0:
        # Part of every person's total that comes from the round offsets.
        shared = people * (full_rounds - 1) * full_rounds // 2
        handout = [full_rounds * (i + 1) + shared for i in range(people)]
    else:
        handout = [0] * people

    person = 0
    while remaining > 0:
        due = full_rounds * people + person + 1
        if remaining <= due:
            handout[person] += remaining
            break
        handout[person] += due
        remaining -= due
        person += 1
    return handout


def solution_1103_2(candies: int, num_people: int) -> List[int]:
    """Distribute ``candies`` in portions 1, 2, 3, ... round-robin, in closed form.

    ``full_gifts`` is the number of complete portions, i.e. the largest ``m``
    with ``m * (m + 1) / 2 <= candies``. It is computed with the exact
    integer square root: a floating-point ``sqrt`` can round up for large
    ``candies`` and over-count the complete portions.

    Every person receives ``rounds`` complete portions, the first ``extra``
    people receive one more, and the leftover goes to the next person in
    line.
    """
    from math import isqrt

    n = num_people
    full_gifts = (isqrt(8 * candies + 1) - 1) // 2
    rounds, extra = divmod(full_gifts, n)
    shares = [
        rounds * (rounds - 1) // 2 * n + rounds * (i + 1)
        + (rounds * n + i + 1 if i < extra else 0)
        for i in range(n)
    ]
    shares[extra] += candies - full_gifts * (full_gifts + 1) // 2
    return shares


def solution_1486(n: int, start: int) -> int:
    """Return the XOR of ``start + 2 * i`` for ``i`` in ``range(n)`` in O(1).

    All terms share the lowest bit of ``start``, so that bit survives only
    when both ``n`` and ``start`` are odd. The remaining bits are the XOR of
    the consecutive integers ``start // 2 .. start // 2 + n - 1``, obtained
    from two "XOR of 0..x" prefix values.
    """

    def xor_upto(x):
        # XOR of all integers 0..x, which cycles with period 4 in x
        # (and is 0 for x == -1).
        remainder = x % 4
        if remainder == 0:
            return x
        if remainder == 1:
            return 1
        if remainder == 2:
            return x + 1
        return 0

    half = start // 2
    low_bit = n & start & 1
    high_bits = xor_upto(half + n - 1) ^ xor_upto(half - 1)
    return high_bits * 2 + low_bit


def solution_1356(arr: List[int]) -> List[int]:
    """Sort ``arr`` in place by number of set bits, then by value, and return it."""

    def popcount(value):
        # Non-positive values are treated as having no set bits.
        return bin(value).count("1") if value > 0 else 0

    arr.sort(key=lambda value: (popcount(value), value))
    return arr


def solution_1318(a: int, b: int, c: int) -> int:
    """Return the number of single-bit flips in ``a``/``b`` so that ``a | b == c``.

    Only the low 32 bits are considered. Every bit where ``a | b`` differs
    from ``c`` needs one flip; a bit that must become 0 while it is set in
    both ``a`` and ``b`` needs a second flip.
    """
    word = 0xFFFFFFFF
    mismatched = ((a | b) ^ c) & word
    set_in_both_but_unwanted = a & b & ~c & word
    return bin(mismatched).count("1") + bin(set_in_both_but_unwanted).count("1")


def solution_1383(n: int, speed: List[int], efficiency: List[int], k: int) -> int:
    """Return the best team performance with at most ``k`` workers, modulo 1e9+7.

    Performance is the sum of the members' speeds multiplied by the lowest
    efficiency in the team. Workers are visited from most to least efficient;
    the current worker fixes the minimum efficiency and is combined with the
    fastest workers seen so far, which are kept in a min-heap.
    """
    modulus = 10 ** 9 + 7
    ranked = sorted(zip(efficiency, speed), reverse=True)
    best = 0
    speed_sum = 0
    fastest = []  # min-heap of the speeds currently counted in speed_sum
    for idx in range(n):
        eff, spd = ranked[idx]
        speed_sum += spd
        best = max(best, eff * speed_sum)
        heappush(fastest, spd)
        if len(fastest) == k:
            # Keep room for the next worker by dropping the slowest one.
            speed_sum -= heappop(fastest)
    return best % modulus


def solution_1108(address: str) -> str:
    """Return ``address`` with every ``.`` replaced by ``[.]``."""
    return '[.]'.join(address.split('.'))


def solution_1186(arr: List[int]) -> int:
    """Return the largest sum of a non-empty subarray with at most one deletion.

    Two running values are kept per position: ``keep`` is the best sum of a
    subarray ending here with nothing deleted, and ``drop`` the best sum
    ending here with exactly one element deleted. For an empty ``arr`` the
    result is ``-inf``.
    """
    best = keep = drop = -inf
    for x in arr:
        # Either extend a subarray that already deleted something, or delete
        # the current element from the best subarray ending just before it.
        drop = max(drop + x, keep)
        # Either extend the previous subarray or start afresh at x.
        keep = max(keep, 0) + x
        best = max(best, keep, drop)
    return best


def solution_1334(n: int, edges: List[List[int]], distanceThreshold: int) -> int:
    """Return the city with the fewest cities reachable within ``distanceThreshold``.

    ``edges`` holds undirected ``[u, v, weight]`` roads between cities
    ``0 .. n - 1``. All-pairs shortest distances are computed with the
    Floyd-Warshall relaxation; ties are resolved in favour of the city with
    the largest index.
    """
    dist = [[inf] * n for _ in range(n)]
    for u, v, weight in edges:
        dist[u][v] = dist[v][u] = weight

    # product() iterates (mid, src, dst) with ``mid`` as the outermost loop.
    for mid, src, dst in itertools.product(range(n), repeat=3):
        via = dist[src][mid] + dist[mid][dst]
        if via < dist[src][dst]:
            dist[src][dst] = via

    best_city = 0
    fewest = inf
    for city in range(n):
        reachable = sum(
            1 for other in range(n)
            if other != city and dist[city][other] <= distanceThreshold
        )
        # ``<=`` lets a later city win ties.
        if reachable <= fewest:
            fewest, best_city = reachable, city
    return best_city


def solution_1247(s1: str, s2: str) -> int:
    """Return the minimum number of cross swaps making ``s1`` equal ``s2``, or -1.

    Only mismatching positions matter. They come in two kinds, depending on
    whether ``s1`` has an ``'x'`` there or not. Two mismatches of the same
    kind are fixed by one swap; one leftover of each kind costs two swaps; a
    single leftover cannot be fixed.
    """
    x_mismatches = other_mismatches = 0
    for i, ch in enumerate(s1):
        if ch == s2[i]:
            continue
        if ch == 'x':
            x_mismatches += 1
        else:
            other_mismatches += 1

    x_pairs, x_left = divmod(x_mismatches, 2)
    other_pairs, other_left = divmod(other_mismatches, 2)
    if x_left != other_left:
        return -1
    return x_pairs + other_pairs + x_left + other_left


def solution_1442(arr: List[int]) -> int:
    """Sum ``k - i`` over all index pairs ``i < k`` with ``arr[i] ^ ... ^ arr[k] == 0``.

    The XOR of ``arr[i..k]`` is zero exactly when the prefix XOR before index
    ``i`` equals the prefix XOR after index ``k``. For every prefix value the
    number of earlier positions and the sum of their indices are tracked, so
    each ``k`` is handled in O(1).
    """
    seen_count, index_sum = Counter(), Counter()
    total = prefix = 0
    for k, val in enumerate(arr):
        after = prefix ^ val
        # A Counter yields 0 for unseen keys, so no membership test is needed.
        total += seen_count[after] * k - index_sum[after]
        # Record the prefix *before* element k (not ``after``): index i in the
        # formula refers to the prefix XOR preceding position i.
        seen_count[prefix] += 1
        index_sum[prefix] += k
        prefix = after
    return total


def solution_1310(arr: List[int], queries: List[List[int]]) -> List[int]:
    """Answer each ``[left, right]`` query with the XOR of ``arr[left..right]``.

    ``prefix[i]`` is the XOR of the first ``i`` elements, so a range XOR is
    the XOR of two prefix values.
    """
    prefix = list(itertools.accumulate(arr, lambda acc, x: acc ^ x, initial=0))
    return [prefix[right + 1] ^ prefix[left] for left, right in queries]


def solution_1035(nums1: List[int], nums2: List[int]) -> int:
    """Return the maximum number of non-crossing lines between equal values.

    A line connects ``nums1[i]`` with an equal ``nums2[j]``; lines may not
    cross and every element takes part in at most one line. This is the
    length of the longest common subsequence of the two lists, computed here
    with the classic row-by-row table in O(len(nums1) * len(nums2)) time and
    O(len(nums2)) memory, without recursion.
    """
    # previous[j] is the best result for the processed part of nums1 against
    # the first j elements of nums2.
    previous = [0] * (len(nums2) + 1)
    for a in nums1:
        current = [0]
        for j, b in enumerate(nums2):
            if a == b:
                current.append(previous[j] + 1)
            else:
                current.append(max(previous[j + 1], current[j]))
        previous = current
    return previous[-1]


def solution_1450(startTime: List[int], endTime: List[int], queryTime: int) -> int:
    """Count the intervals ``[startTime[i], endTime[i]]`` that contain ``queryTime``.

    Both interval ends are inclusive. The intervals are checked directly, so
    there is no upper bound on the time values.
    """
    return sum(
        1 for start, end in zip(startTime, endTime)
        if start <= queryTime <= end
    )


def solution_1184(distance: List[int], start: int, destination: int) -> int:
    """Return the shorter way between two stops on a circular route.

    ``distance[i]`` is the length of the leg from stop ``i`` to the next
    stop; the route can be travelled in either direction.
    """
    # prefix[i] is the total length of the first i legs.
    prefix = [0]
    for leg in distance:
        prefix.append(prefix[-1] + leg)
    lower, upper = sorted((start, destination))
    one_way = prefix[upper] - prefix[lower]
    return min(one_way, prefix[-1] - one_way)


def solution_1014(values: List[int]) -> int:
    """Return the maximum of ``values[i] + values[j] + i - j`` over all ``i < j``.

    The score splits into ``(values[i] + i) + (values[j] - j)``, so while
    scanning ``j`` from left to right it is enough to remember the best
    ``values[i] + i`` seen so far. With a single element there is no pair and
    the result is ``-inf``.
    """
    best = -inf
    best_left = values[0]  # values[i] + i for i == 0
    for j in range(1, len(values)):
        best = max(best, best_left + values[j] - j)
        best_left = max(best_left, values[j] + j)
    return best


def solution_1227(n: int) -> float:
    """Return ``1`` when ``n`` is 1 and ``0.5`` for every other ``n``.

    This is the closed form of the probability that the original
    (commented-out) iterative computation arrived at.
    """
    return 1 if n == 1 else 0.5


def solution_1436(paths: List[List[str]]) -> str:
    """Return a city that appears as a destination but never as an origin.

    Each entry of ``paths`` is an ``[origin, destination]`` pair.
    """
    origins = {origin for origin, _ in paths}
    destinations = {destination for _, destination in paths}
    terminal = list(destinations - origins)
    return terminal[0]
